import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class GraphAttentionLayer(nn.Module):
    """
    Dense graph attention (GAT) layer, similar to https://arxiv.org/abs/1710.10903

    Node features are projected with ``W``; a pairwise logit
    ``e_ij = LeakyReLU(a^T [W h_i || W h_j])`` is computed for every node
    pair, masked with ``adj``, normalised with a softmax over ``j`` and used
    to aggregate the projected features.
    """

    def __init__(self, in_features, out_features, dropout=False, alpha=0.2, concat=False):
        """
        Dense version of GAT.

        :param in_features: size of each input node feature vector
        :param out_features: size of each output node feature vector
        :param dropout: stored on the module as ``self.dropout``; it is not
            applied anywhere in ``forward``
        :param alpha: negative slope of the LeakyReLU applied to the
            attention logits (0.2 is the usual choice)
        :param concat: if True, ``forward`` applies ELU to the output and
            additionally returns the attention matrix; if False, it returns
            only the raw aggregated features
        """
        super(GraphAttentionLayer, self).__init__()
        # NOTE(review): kept for interface compatibility only -- forward()
        # never uses it. Confirm what callers pass before wiring it in.
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        # Feature projection, (in_features, out_features).
        self.W = nn.Parameter(torch.empty(in_features, out_features))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)

        # Attention vector, (2 * out_features, 1): the first half scores the
        # "source" node i, the second half scores the "target" node j.
        self.a = nn.Parameter(torch.empty(2 * out_features, 1))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        """
        Compute masked attention over the nodes and aggregate their features.

        :param input: node features of shape (..., N, in_features); the
            original use is batched, (B, N, in_features), but unbatched
            (N, in_features) input works as well
        :param adj: mask broadcastable to (..., N, N). Logits are kept where
            ``adj < 1`` and suppressed everywhere else.
            NOTE(review): this is the opposite of the reference GAT code,
            which keeps entries with ``adj > 0`` -- confirm the convention
            the caller uses for ``adj``.
        :return: ``(elu(h_prime), attention)`` when ``self.concat`` is True,
            otherwise only ``h_prime``; ``h_prime`` has shape
            (..., N, out_features) and ``attention`` has shape (..., N, N)
        """
        h = torch.matmul(input, self.W)  # (..., N, C) with C = out_features

        # a^T [h_i || h_j] == a[:C]^T h_i + a[C:]^T h_j, so the pairwise
        # logits are a broadcast sum of two per-node scores. This avoids
        # materialising the (..., N, N, 2C) tensor of concatenated pairs.
        src_score = torch.matmul(h, self.a[:self.out_features])   # (..., N, 1)
        dst_score = torch.matmul(h, self.a[self.out_features:])   # (..., N, 1)
        e = self.leakyrelu(src_score + dst_score.transpose(-1, -2))  # (..., N, N)

        # Large negative fill so masked pairs get ~0 weight after softmax.
        # -9e15 overflows to -inf in float16 (a fully masked row would then
        # turn into NaN), so clamp it to the dtype's finite minimum; for
        # float32/float64/bfloat16 the fill is still exactly -9e15.
        fill_value = max(-9e15, torch.finfo(e.dtype).min)
        attention = torch.where(adj < 1, e, torch.full_like(e, fill_value))

        # Normalise over the neighbour axis j.
        attention = F.softmax(attention, dim=-1)

        # (..., N, N) @ (..., N, C) -> (..., N, C)
        h_prime = torch.matmul(attention, h)

        if self.concat:
            return F.elu(h_prime), attention
        return h_prime

    def __repr__(self):
        return '{} ({} -> {})'.format(self.__class__.__name__, self.in_features, self.out_features)